#include "database_workerpool.h"

#include <mysqld_error.h>

#include "log.h"
#include "report.h"
#include "pc_queue.h"
#include "database_mysql_head.h"
#include "database_adhocstatement.h"
#include "database_preparedstatement_mysql.h"
#include "database_preparedstatement.h"
#include "database_query_result.h"
#include "database_query_holder.h"
#include "database_query_callback.h"
#include "database_transaction.h"
#include "database_sql_op.h"
#include "database_mysql_connection.h"

#include "connection_test.h"

namespace afcore {
namespace database {

// Oldest MySQL server and client library versions accepted by the pool,
// in the numeric form returned by mysql_get_client_version() (50700 == 5.7.0).
#define MIN_MYSQL_SERVER_VERSION (50700u)
#define MIN_MYSQL_CLIENT_VERSION (50700u)

// Queued operation that pings the server over the connection it runs on.
// It always reports success; the result of Ping() is not inspected.
class CPingOperation : public CSqlOperation {
 private:
  bool Execute() override {
    conn_->Ping();
    return true;
  }
};

template<typename T>
CDatabaseWorkerPool<T>::CDatabaseWorkerPool()
  : q_(new CPcQueue<CSqlOperation*>()) {
  DBG_FATAL(mysql_thread_safe(), "Used MySQL library isn't thread-safe.");
  DBG_FATAL(mysql_get_client_version() >= MIN_MYSQL_CLIENT_VERSION, "Asfinger does not support MySQL versions below 5.7");
  DBG_FATAL(mysql_get_client_version() == MYSQL_VERSION_ID, "Used MySQL library version (%s) does not match the version used to compile Asfinger (%s). Search on forum for TCE00011.",
          mysql_get_client_info(), MYSQL_SERVER_VERSION);
}

// Cancels the shared operation queue when the pool is destroyed
// (presumably releasing any worker blocked on it — confirm in CPcQueue).
template<typename T>
CDatabaseWorkerPool<T>::~CDatabaseWorkerPool() {
  auto& queue = *q_;
  queue.Cancel();
}

/// Stores the connection parameters and how many asynchronous and synchronous
/// connections Open() should create. No connection is opened here.
/// @param conn_info     connection string used to build SMysqlConnectionInfo.
/// @param async_threads number of asynchronous (queue-driven) connections.
/// @param synch_threads number of synchronous connections.
template<typename T>
void CDatabaseWorkerPool<T>::SetConnectionInfo(const std::string &conn_info, uint8_t async_threads, uint8_t synch_threads) {
  // Build the connection info before touching the counters.
  this->conn_info_ = std::make_unique<SMysqlConnectionInfo>(conn_info);

  this->async_threads_ = async_threads;
  this->synch_threads_ = synch_threads;
}

/// Opens every configured connection: first the asynchronous ones, then the
/// synchronous ones. SetConnectionInfo() must have been called beforehand.
/// @return 0 on success, otherwise the non-zero error returned by
///         OpenConnections() for the first connection group that failed.
template<typename T>
uint32_t CDatabaseWorkerPool<T>::Open() {
  DBG_FATAL(conn_info_.get(), "Connection info was not set!");

  LOG_INFO("Opening DatabasePool '{}'. Asynchronous connections: {}, synchronous connections: {}.",
    GetDatabaseName(), async_threads_, synch_threads_);

  uint32_t error = OpenConnections(kInternalIndex_Async, async_threads_);
  if (error) {
    return error;
  }

  error = OpenConnections(kInternalIndex_Synch, synch_threads_);
  if (!error) {
    // Total = asynchronous + synchronous (the synchronous group used to be
    // counted twice and the asynchronous one not at all).
    LOG_INFO("DatabasePool '{}' opened successfully. {} total connections running.", GetDatabaseName(),
      (conns_[kInternalIndex_Async].size() + conns_[kInternalIndex_Synch].size()));
  }

  return error;
}

/// Closes the pool by destroying the asynchronous connections first and then
/// the synchronous ones. No locking is done here; presumably callers ensure no
/// other thread still uses the connections — confirm at call sites.
template<typename T>
void CDatabaseWorkerPool<T>::Close() {
  LOG_INFO("Closing down DatabasePool '{}'.", GetDatabaseName());

  // Clearing the container destroys the owned connection objects.
  conns_[kInternalIndex_Async].clear();

  LOG_INFO("Asynchronous connections on DatabasePool '{}' terminated. Proceeding with synchronous connections.",
    GetDatabaseName());

  conns_[kInternalIndex_Synch].clear();

  LOG_INFO("All connections on DatabasePool '{}' closed.", GetDatabaseName());
}

/// Prepares the statements on every connection (asynchronous and synchronous)
/// and records, per statement index, how many parameters the statement takes.
/// If any connection fails to prepare, the whole pool is closed.
/// @return true when every connection prepared its statements, false otherwise.
template<typename T>
bool CDatabaseWorkerPool<T>::PrepareStatements() {
  for (auto& conns : conns_) {
    for (auto& conn : conns) {
      // LockIfReady() can fail (other code in this file retries or skips on
      // false). Retry until the lock is held so the Unlock() below never
      // releases a lock this thread does not own.
      while (!conn->LockIfReady()) {
      }

      const bool prepared = conn->PrepareStatements();
      conn->Unlock();
      if (!prepared) {
        Close();
        return false;
      }

      const size_t prepared_size = conn->stmts_.size();
      if (prepared_stmt_size_.size() < prepared_size) {
        prepared_stmt_size_.resize(prepared_size);
      }

      for (size_t i = 0; i < prepared_size; ++i) {
        // Already recorded from a previously processed connection.
        if (prepared_stmt_size_[i] > 0) {
          continue;
        }

        if (CPreparedStatementMysql* stmt = conn->stmts_[i].get()) {
          const uint32_t param_count = stmt->GetParameterCount();
          // Parameter counts are stored as uint8_t.
          DBG_ASSERT(param_count < std::numeric_limits<uint8_t>::max());
          prepared_stmt_size_[i] = static_cast<uint8_t>(param_count);
        }
      }
    }
  }

  return true;
}

/// Queues a plain SQL statement on the shared operation queue. Statements
/// rejected by IsFormatEmptyOrNull() are silently dropped.
template<typename T>
void CDatabaseWorkerPool<T>::Execute(const char *sql) {
  if (!IsFormatEmptyOrNull(sql)) {
    Enqueue(new CBasicStatementTask(sql));
  }
}

/// Queues a prepared statement on the shared operation queue.
/// NOTE(review): presumably the task takes ownership of @p stmt (the Direct*
/// overloads delete it themselves) — confirm in CPreparedStatementTask.
template<typename T>
void CDatabaseWorkerPool<T>::Execute(CPreparedStatement<T> *stmt) {
  Enqueue(new CPreparedStatementTask(stmt));
}

/// Runs a plain SQL statement right away on a synchronous connection, blocking
/// until GetFreeConnection() can lock one. Statements rejected by
/// IsFormatEmptyOrNull() are ignored.
template<typename T>
void CDatabaseWorkerPool<T>::DirectExecute(const char *sql) {
  if (!IsFormatEmptyOrNull(sql)) {
    T* const free_conn = GetFreeConnection();
    free_conn->Execute(sql);
    free_conn->Unlock();
  }
}

/// Runs a prepared statement right away on a locked synchronous connection and
/// then deletes @p stmt; this overload consumes the statement.
template<typename T>
void CDatabaseWorkerPool<T>::DirectExecute(CPreparedStatement<T> *stmt) {
  T* const free_conn = GetFreeConnection();
  free_conn->Execute(stmt);
  free_conn->Unlock();

  // The statement is no longer needed once executed.
  delete stmt;
}

/// Runs a plain SQL query synchronously.
/// @param sql        query text.
/// @param connection connection to use; when null a free synchronous
///                   connection is locked. Either way it is unlocked here.
/// @return the result positioned on its first row, or an empty pointer when
///         the query produced no result set or no rows.
template<typename T>
RQueryResultSptr CDatabaseWorkerPool<T>::Query(const char *sql, T *connection) {
  T* const conn = connection ? connection : GetFreeConnection();

  CResultSet* const result = conn->Query(sql);
  conn->Unlock();

  // Short-circuit order matters: NextRow() advances onto the first row.
  const bool has_rows = result && result->GetRowCount() && result->NextRow();
  if (has_rows) {
    return RQueryResultSptr(result);
  }

  delete result;
  return RQueryResultSptr(nullptr);
}

/// Runs a prepared query synchronously on a locked synchronous connection.
/// @p stmt is deleted by this call.
/// @return the result set, or an empty pointer when there is no result or it
///         has zero rows.
template<typename T>
RPreparedQueryResultSptr CDatabaseWorkerPool<T>::Query(CPreparedStatement<T> *stmt) {
  T* const free_conn = GetFreeConnection();
  CPreparedResultSet* const result = free_conn->Query(stmt);
  free_conn->Unlock();

  // This overload consumes the statement.
  delete stmt;

  if (result && result->GetRowCount()) {
    return RPreparedQueryResultSptr(result);
  }

  delete result;
  return RPreparedQueryResultSptr(nullptr);
}

/// Queues a plain SQL query and returns a callback wrapping the future that
/// will receive its result.
template<typename T>
CQueryCallback CDatabaseWorkerPool<T>::AsyncQuery(const char *sql) {
  auto* const task = new CBasicStatementTask(sql, true);
  // Fetch the future before the task is handed to the queue.
  RQueryResultFutrue future = task->GetFuture();
  Enqueue(task);
  return CQueryCallback(std::move(future));
}

/// Queues a prepared query and returns a callback wrapping the future that
/// will receive its result.
template<typename T>
CQueryCallback CDatabaseWorkerPool<T>::AsyncQuery(CPreparedStatement<T> *stmt) {
  auto* const task = new CPreparedStatementTask(stmt, true);
  // Fetch the future before the task is handed to the queue.
  RPreparedQueryResultFutrue future = task->GetFuture();
  Enqueue(task);
  return CQueryCallback(std::move(future));
}

/// Queues every query of @p holder as one task and returns the future that is
/// fulfilled when the holder has been processed.
template<typename T>
RQueryHolderFutrue CDatabaseWorkerPool<T>::DelayQueryHolder(CSqlQueryHolder<T> *holder) {
  auto* const task = new CSqlQueryHolderTask(holder);
  // Fetch the future before the task is handed to the queue.
  RQueryHolderFutrue future = task->GetFuture();
  Enqueue(task);
  return future;
}

/// Creates a new transaction object. Statements appended to it are executed
/// by CommitTransaction() or DirectCommitTransaction().
template<typename T>
RSqlTransaction<T> CDatabaseWorkerPool<T>::BeginTransaction() {
  RSqlTransaction<T> transaction = std::make_shared<CTransaction<T>>();
  return transaction;
}

/// Queues @p transaction for asynchronous execution. In AFCORE_DEBUG builds an
/// empty transaction is logged and dropped, and a single-query transaction is
/// logged as a hint; release builds always queue it.
template<typename T>
void CDatabaseWorkerPool<T>::CommitTransaction(RSqlTransaction<T> transaction) {
#if defined(AFCORE_DEBUG)
  const auto query_count = transaction->GetSize();
  if (query_count == 0) {
    LOG_DEBUG("Transaction contains 0 queries. Not executing.");
    return;
  }
  if (query_count == 1) {
    LOG_DEBUG("Transaction only holds 1 query, consider removing Transaction context in code.");
  }
#endif

  Enqueue(new CTransactionTask(transaction));
}

/// Executes @p transaction right away on a locked synchronous connection.
/// When the first attempt fails with ER_LOCK_DEADLOCK it is retried up to five
/// more times, stopping at the first success. Whenever the first attempt fails,
/// CleanUp() is called on the transaction afterwards, whatever the retries did.
template<typename T>
void CDatabaseWorkerPool<T>::DirectCommitTransaction(RSqlTransaction<T> &transaction) {
  T* const free_conn = GetFreeConnection();
  const int first_error = free_conn->ExecuteTransaction(transaction);

  if (first_error != 0) {
    if (first_error == ER_LOCK_DEADLOCK) {
      // Bounded retries; a zero return from ExecuteTransaction means success.
      constexpr uint8_t kMaxDeadlockRetries = 5;
      uint8_t attempt = 0;
      while (attempt < kMaxDeadlockRetries && free_conn->ExecuteTransaction(transaction)) {
        ++attempt;
      }
    }
    transaction->CleanUp();
  }

  free_conn->Unlock();
}

/// Appends @p sql to @p trans when a transaction is active, otherwise queues
/// it for standalone execution via Execute().
template<typename T>
void CDatabaseWorkerPool<T>::ExecuteOrAppend(RSqlTransaction<T> &trans, const char *sql) {
  if (trans) {
    trans->Append(sql);
    return;
  }
  Execute(sql);
}

/// Appends @p stmt to @p trans when a transaction is active, otherwise queues
/// it for standalone execution via Execute().
template<typename T>
void CDatabaseWorkerPool<T>::ExecuteOrAppend(RSqlTransaction<T> &trans, CPreparedStatement<T> *stmt) {
  if (trans) {
    trans->Append(stmt);
    return;
  }
  Execute(stmt);
}

/// Allocates a new prepared statement for @p index, sized with the parameter
/// count PrepareStatements() recorded for that index. The caller receives the
/// heap object and passes it to one of the Execute/Query overloads.
template<typename T>
CPreparedStatement<T> *CDatabaseWorkerPool<T>::GetPreparedStatement(RPreparedStatementIndex index) {
  // An index never recorded by PrepareStatements() (or a call made before it
  // ran) would make operator[] read past the end of the vector.
  DBG_ASSERT(static_cast<size_t>(index) < prepared_stmt_size_.size());
  return new CPreparedStatement<T>(index, prepared_stmt_size_[index]);
}

/// Escapes @p str in place through the char-buffer EscapeString overload
/// (presumably mysql_real_escape_string underneath — confirm in T).
/// Empty strings are left untouched.
template<typename T>
void CDatabaseWorkerPool<T>::EscapeString(std::string &str) {
  if (str.empty()) {
    return;
  }

  // Worst case every byte is escaped (2x) plus the terminating NUL. The buffer
  // is owned by a std::string, so it is released even if an exception is
  // thrown. It is also zero-filled, so nothing uninitialized is read if the
  // escape writes nothing.
  std::string buf(str.size() * 2 + 1, '\0');
  EscapeString(&buf[0], str.c_str(), static_cast<unsigned long>(str.size()));

  // Copy up to the NUL terminator, exactly like the previous `str = buf`.
  str.assign(buf.c_str());
}

/// Keeps the pool's connections from idling: free synchronous connections are
/// pinged directly, and one CPingOperation per asynchronous connection is
/// placed on the shared operation queue.
template<typename T>
void CDatabaseWorkerPool<T>::KeepAlive() {
  // Ping synchronous connections that are free right now; busy ones skip.
  for (auto& conn : conns_[kInternalIndex_Synch]) {
    if (conn->LockIfReady()) {
      conn->Ping();
      conn->Unlock();
    }
  }

  // Asynchronous connections are created with the operation queue and
  // presumably only do work that arrives through it, so they were never
  // pinged. Queue one ping per asynchronous connection. If some workers are
  // busy the pings may not be spread evenly, which is acceptable: any traffic
  // keeps a connection from idling.
  const size_t async_count = conns_[kInternalIndex_Async].size();
  for (size_t i = 0; i < async_count; ++i) {
    Enqueue(new CPingOperation());
  }
}

/// Opens @p conns_num connections of the given @p type and stores them in the
/// pool. Asynchronous connections are constructed with the shared operation
/// queue. On any failure, every connection of @p type opened so far is
/// discarded.
/// @return 0 on success, the connection's Open() error code, or 1 when the
///         server is older than MIN_MYSQL_SERVER_VERSION.
template<typename T>
uint32_t CDatabaseWorkerPool<T>::OpenConnections(CDatabaseWorkerPool::EInternalIndex type, uint8_t conns_num) {
  for (uint8_t i = 0; i < conns_num; ++i) {
    auto conn = [&] {
      switch(type) {
        case kInternalIndex_Async:
          return std::make_unique<T>(*conn_info_, q_.get());
        case kInternalIndex_Synch:
          return std::make_unique<T>(*conn_info_);
        default:
          DBG_ABORT();
      }
    }();

    if (uint32_t error = conn->Open()) {
      conns_[type].clear();
      return error;
    }

    if (conn->GetServerVersion() < MIN_MYSQL_SERVER_VERSION) {
      LOG_ERROR("Asfinger does not support MySQL versions below 5.7");
      // Same cleanup as the Open() failure path so the pool is not left
      // partially populated.
      conns_[type].clear();
      return 1;
    }

    conns_[type].push_back(std::move(conn));
  }

  return 0;
}

/// Escapes @p length bytes of @p from into @p to using the first synchronous
/// connection. Callers in this file size @p to as 2 * length + 1 bytes.
/// @return the value returned by the connection's EscapeString, or 0 when an
///         argument is null/zero or no synchronous connection is open.
template<typename T>
unsigned long CDatabaseWorkerPool<T>::EscapeString(char *to, const char *from, unsigned long length) {
  if (!to || !from || !length) {
    return 0;
  }

  auto& synch_conns = conns_[kInternalIndex_Synch];
  if (synch_conns.empty()) {
    // front() on an empty container is undefined behavior. Report the problem
    // and hand back a valid empty C string instead.
    LOG_ERROR("Cannot escape string: no synchronous database connections are open.");
    *to = '\0';
    return 0;
  }

  return synch_conns.front()->EscapeString(to, from, length);
}

/// Hands @p op to the shared operation queue.
template<typename T>
void CDatabaseWorkerPool<T>::Enqueue(CSqlOperation *op) {
  auto& queue = *q_;
  queue.Enqueue(op);
}

/// Returns a synchronous connection that this call has locked. It probes the
/// synchronous connections round-robin and spins until LockIfReady()
/// succeeds. The caller must Unlock() the returned connection.
template<typename T>
T *CDatabaseWorkerPool<T>::GetFreeConnection() {
  auto& synch_conns = conns_[kInternalIndex_Synch];
  const size_t num_conns = synch_conns.size();
  // With no synchronous connections the loop below would index an empty
  // container and take a modulo by zero.
  DBG_FATAL(num_conns > 0, "No synchronous connections available on DatabasePool.");

  // The cursor stays in [0, num_conns), so probing is strictly round-robin
  // (a uint8_t counter wraps at 256 and restarts the scan out of order).
  size_t index = 0;
  for (;;) {
    T* conn = synch_conns[index].get();
    if (conn->LockIfReady()) {
      return conn;
    }
    index = (index + 1) % num_conns;
  }
}

/// Name of the database taken from the connection info stored by
/// SetConnectionInfo(). conn_info_ is dereferenced without a null check.
template<typename T>
const char* CDatabaseWorkerPool<T>::GetDatabaseName() const {
  const auto& info = *conn_info_;
  return info.database.c_str();
}

// Explicit instantiation: the member templates are defined in this .cpp file,
// so each connection type used with the pool has to be instantiated here.
template class AFCORE_DATABASE_API CDatabaseWorkerPool<CTestConnection>;

} // !namespace database
} // !namespace afcore